Optimize GPU checkpoint loading by ensuring model transfer before load_state_dict in the build method #325
Conversation
…ng state_dict
Moved `prefill_model` and `decode_model` to the target device before calling `load_state_dict` to avoid redundant tensor transfers by PyTorch.
Let's go, it's really good 👍

Can you review this PR?

@nimanikoo Can I ask where you got my name to review this PR? I don't own this repository and a cursory glance doesn't yield any results implying I should be a required reviewer. I'm curious because this is not the only repository where this keeps happening. I do have (limited) oversight, but not to the degree where I should be approving PRs. Thanks.

Thanks for the heads-up!
While reviewing the code in this repository, I noticed a few areas that could be optimized for efficiency. I decided to make some changes to how the models are loaded onto the GPU before applying their checkpoints. I believe this should have a positive impact on the performance and overall behavior of the code.
Thanks to everyone who contributed to this repo; I really appreciate all the hard work that went into it.
Summary
This PR ensures that both `prefill_model` and `decode_model` are moved to the target device (e.g., GPU) before invoking `load_state_dict`.
Motivation
Previously, if the models were still on the CPU when their checkpoints were loaded, PyTorch first copied the checkpoint tensors into the CPU-resident parameters, and the whole model was then transferred to the GPU, so the weights were moved more times than necessary. By explicitly moving the models to the correct device first, we avoid the redundant transfers and improve checkpoint loading efficiency.
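For concreteness, here is a minimal sketch of the ordering this PR moves away from. The module, its shape, and the checkpoint path are illustrative stand-ins, not this repository's actual code:

```python
import torch
import torch.nn as nn

# Illustrative stand-ins; the real models and checkpoint paths differ.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prefill_model = nn.Linear(1024, 1024)  # parameters start out on the CPU

# Checkpoint tensors are materialized directly on the target device...
state = torch.load("prefill.pt", map_location=device)

# ...but load_state_dict copies them into the CPU-resident parameters,
# and the later .to(device) then transfers every parameter again.
prefill_model.load_state_dict(state)
prefill_model = prefill_model.to(device)
```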
Changes
- Call `prefill_model.to(device)` before loading its state dict.
- Call `decode_model.to(device)` before loading its state dict.

The sketch below shows the resulting order of operations.
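This is a minimal sketch of the change under the same illustrative assumptions as above (module shapes and checkpoint paths are hypothetical):

```python
import torch
import torch.nn as nn

# Illustrative stand-ins for the two models touched by this PR.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prefill_model = nn.Linear(1024, 1024)
decode_model = nn.Linear(1024, 1024)

# Move each model to the target device first...
prefill_model = prefill_model.to(device)
decode_model = decode_model.to(device)

# ...so load_state_dict copies device-resident checkpoint tensors straight
# into device-resident parameters, with no extra CPU round trip.
prefill_model.load_state_dict(torch.load("prefill.pt", map_location=device))
decode_model.load_state_dict(torch.load("decode.pt", map_location=device))
```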
Impact
This reduces unnecessary GPU/CPU transfers during checkpoint loading, which should result in faster and more efficient model initialization.